{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a61bc9eb-d501-4108-a4ce-425715db2e5f",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "2023-03-10T05:02:36.763263Z",
     "shell.execute_reply.started": "2023-03-10T05:02:35.198899Z",
     "to_execute": "2023-03-10T05:02:35.104Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from trailblazer4ml.causal_representation.datasets.recon_dataset import ReconDataset\n",
    "from trailblazer4ml.causal_representation.representation.recon_base import ReconBase\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a9890e48-f60f-43e8-abdb-920a8d4a7592",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "2023-03-10T05:02:39.962460Z",
     "shell.execute_reply.started": "2023-03-10T05:02:39.951319Z",
     "to_execute": "2023-03-10T05:02:39.851Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8d1c9580-d396-404f-a038-36f1d1851b78",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "2023-03-10T05:03:54.310357Z",
     "shell.execute_reply.started": "2023-03-10T05:03:54.274512Z",
     "to_execute": "2023-03-10T05:03:54.180Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>x0</th>\n",
       "      <th>x1</th>\n",
       "      <th>x2</th>\n",
       "      <th>x3</th>\n",
       "      <th>x4</th>\n",
       "      <th>x5</th>\n",
       "      <th>x6</th>\n",
       "      <th>x7</th>\n",
       "      <th>x8</th>\n",
       "      <th>x9</th>\n",
       "      <th>treatment</th>\n",
       "      <th>y</th>\n",
       "      <th>true_ite</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>-1.904249</td>\n",
       "      <td>-0.295453</td>\n",
       "      <td>0.249461</td>\n",
       "      <td>5.815116</td>\n",
       "      <td>2.511096</td>\n",
       "      <td>-2.792666</td>\n",
       "      <td>-1.026020</td>\n",
       "      <td>1.670825</td>\n",
       "      <td>4.908961</td>\n",
       "      <td>2.077141</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.499994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.022887</td>\n",
       "      <td>-3.529255</td>\n",
       "      <td>-3.311369</td>\n",
       "      <td>-7.325015</td>\n",
       "      <td>-4.481658</td>\n",
       "      <td>-3.161529</td>\n",
       "      <td>3.259369</td>\n",
       "      <td>4.046212</td>\n",
       "      <td>2.875123</td>\n",
       "      <td>8.327909</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.952451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>-4.901207</td>\n",
       "      <td>7.055455</td>\n",
       "      <td>0.310905</td>\n",
       "      <td>0.410686</td>\n",
       "      <td>-1.827519</td>\n",
       "      <td>5.738432</td>\n",
       "      <td>2.005312</td>\n",
       "      <td>3.279442</td>\n",
       "      <td>3.643961</td>\n",
       "      <td>4.973827</td>\n",
       "      <td>1.0</td>\n",
       "      <td>1.0</td>\n",
       "      <td>0.952451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>-6.399633</td>\n",
       "      <td>-0.815524</td>\n",
       "      <td>1.568713</td>\n",
       "      <td>2.184760</td>\n",
       "      <td>-3.212778</td>\n",
       "      <td>3.184206</td>\n",
       "      <td>-0.650805</td>\n",
       "      <td>3.041452</td>\n",
       "      <td>4.122021</td>\n",
       "      <td>7.049287</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.499994</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>-1.998128</td>\n",
       "      <td>1.581029</td>\n",
       "      <td>0.141309</td>\n",
       "      <td>1.901191</td>\n",
       "      <td>-5.506624</td>\n",
       "      <td>0.583806</td>\n",
       "      <td>1.378567</td>\n",
       "      <td>5.379643</td>\n",
       "      <td>-1.844670</td>\n",
       "      <td>1.523271</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.499994</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         x0        x1        x2        x3        x4        x5        x6  \\\n",
       "0 -1.904249 -0.295453  0.249461  5.815116  2.511096 -2.792666 -1.026020   \n",
       "1  7.022887 -3.529255 -3.311369 -7.325015 -4.481658 -3.161529  3.259369   \n",
       "2 -4.901207  7.055455  0.310905  0.410686 -1.827519  5.738432  2.005312   \n",
       "3 -6.399633 -0.815524  1.568713  2.184760 -3.212778  3.184206 -0.650805   \n",
       "4 -1.998128  1.581029  0.141309  1.901191 -5.506624  0.583806  1.378567   \n",
       "\n",
       "         x7        x8        x9  treatment    y  true_ite  \n",
       "0  1.670825  4.908961  2.077141        1.0  1.0  0.499994  \n",
       "1  4.046212  2.875123  8.327909        0.0  0.0  0.952451  \n",
       "2  3.279442  3.643961  4.973827        1.0  1.0  0.952451  \n",
       "3  3.041452  4.122021  7.049287        0.0  0.0  0.499994  \n",
       "4  5.379643 -1.844670  1.523271        0.0  0.0  0.499994  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('../trailblazer/trailblazer4ml/causal_inference/dim-10-datasets/data_x10_1000/10_dim_x_y_ite_1000_sample.csv', index_col=0)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "438ea1bb-2fb5-45a5-bf44-54f323846b2d",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "2023-03-10T05:03:57.888968Z",
     "shell.execute_reply.started": "2023-03-10T05:03:57.883006Z",
     "to_execute": "2023-03-10T05:03:57.786Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "covariates = ['x0', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9']\n",
    "treat = ['treatment']\n",
    "outcomes = ['y']\n",
    "causal_name = ['true_ite']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c94a3a80-d085-4f5f-8676-b00d980a03d6",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "2023-03-10T05:03:58.675109Z",
     "shell.execute_reply.started": "2023-03-10T05:03:58.662400Z",
     "to_execute": "2023-03-10T05:03:58.570Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "data = ReconDataset(df, covariates=covariates, treat=treat, outcomes=outcomes, causal_name=causal_name,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ce3be479-4dd5-4381-8231-a67dc1dc6ae2",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "2023-03-10T05:04:01.029335Z",
     "shell.execute_reply.started": "2023-03-10T05:04:01.004193Z",
     "to_execute": "2023-03-10T05:04:00.911Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data shape: (1000, 13)\n"
     ]
    }
   ],
   "source": [
    "rec = ReconBase(data, device=device)\n",
    "rec.get_Dataloader()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e9fc8ea",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "33e6da7e-b0de-4194-b7e2-6028cad10212",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "2023-03-10T05:04:02.947862Z",
     "to_execute": "2023-03-10T05:04:02.855Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.Recon_fit(reconstructor='VAE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bab0574-a558-4923-93ac-3f4c2bf8bf8f",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "",
     "to_execute": "2023-03-10T05:04:05.308Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3054d54-9911-48ac-b362-88e67140bc02",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "",
     "to_execute": "2023-03-10T05:04:07.218Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.Recon_transform(reconstructor='VAE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d082c334-f4f4-4bb0-a60c-030b7041ed8b",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "",
     "to_execute": "2023-03-10T05:04:10.266Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.output_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5dce234",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6cac77f-797b-4ab8-b843-8127da665ade",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "",
     "to_execute": "2023-03-10T05:04:10.898Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.Recon_fit(reconstructor='AE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dcc956b-b4c6-40e3-aa43-5db0a39f12f9",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "",
     "to_execute": "2023-03-10T05:04:11.284Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.Recon_transform(reconstructor='AE')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85768a7c-3b0b-425a-b9e3-8cfe53ae5afc",
   "metadata": {
    "collapsed": false,
    "execution": {
     "shell.execute_reply.end": "",
     "shell.execute_reply.started": "",
     "to_execute": "2023-03-10T05:04:13.667Z"
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "rec.output_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd9785f-968e-40fc-bd2a-817c6b80258e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
